#!/usr/bin/env python

"""
__file__ = EDAFollowUP.py:

__author__ = "William Bosl"
__copyright__ = "Copyright 2022, William J. Bosl"
__credits__ = ["William Bosl"]
__license__ = All rights reserved by William J. Bosl
__version__ = "1.0.0"
__maintainer__ = "William Bosl"
__email__ = "wjbosl@gmail.com"
__status__ = "Initial test"
"""
#
# See scikit-learn documentation for more information
# https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation
#
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_predict
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn import preprocessing
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix
from sklearn.feature_selection import RFECV
from sklearn.model_selection import permutation_test_score


# Column configuration.
# The first entry names the label column; every remaining entry names a
# feature column.
all_columns = [
    'Group', 'Sex', 'Age',
    'A_EDA', 'M_EDA', 'A_TEMP', 'A_HR', 'M_HR', 'M_TEMP',
    'Age1Seizure', 'MRIFindings', 'ASM_Red',
    'NormalEEG', 'FocalSlowing', 'GeneralizedSlowing', 'Spikes',
]

# Alternative column sets (replace the list above to use one of them):
#
# all_columns = [
#     'Group', 'Sex', 'Age',
#     'Age1Seizure', 'MRIFindings', 'ASM_Red',
#     'NormalEEG', 'FocalSlowing', 'GeneralizedSlowing', 'Spikes',
# ]
#
# all_columns = [
#     'Group', 'Sex', 'Age',
#     'A_EDA', 'M_EDA', 'A_TEMP', 'A_HR', 'M_HR', 'M_TEMP',
# ]
#
# Further candidate column, currently not included: 'EtiolNum'

# Number of random label shuffles used for the chance-level comparison.
NSHUFFLES = 200

# When True, the script entry point appends a synthetic feature built from
# the labels plus Gaussian noise (standard deviation `noise_sd`).
TEST = False
noise_sd = 0.5


def _mean_ci95(fold_scores):
    """Return (mean, lower, upper): normal-approximation 95% CI of the mean fold score."""
    fold_scores = np.asarray(fold_scores, dtype=float)
    mean = fold_scores.mean()
    # The standard deviation is taken over the per-fold scores, so the standard
    # error of their mean uses the number of fold scores, not the number of subjects.
    half_width = 1.96 * fold_scores.std() / np.sqrt(len(fold_scores))
    return mean, mean - half_width, mean + half_width


def _rfecv_mean_scores(rfecv):
    """Return the mean cross-validation score per candidate feature count of a fitted RFECV."""
    cv_results = getattr(rfecv, "cv_results_", None)
    if cv_results is not None:
        return np.asarray(cv_results["mean_test_score"])
    # Older scikit-learn releases expose only `grid_scores_`; some of them
    # return one column per CV split, in which case average across splits.
    scores = np.asarray(rfecv.grid_scores_)
    return scores.mean(axis=1) if scores.ndim == 2 else scores


def classify(X, Y, nr=0):
    """Compare several classifiers by cross-validation, then rank features with RFECV.

    For each classifier the function prints accuracy, sensitivity and
    specificity (from cross-validated predictions), the mean Brier score and
    mean ROC AUC with approximate 95% confidence intervals, and the result of
    a label-permutation test on accuracy. It then runs recursive feature
    elimination with a linear SVM, prints the outcome and shows a plot of the
    cross-validation score against the number of selected features.

    Parameters
    ----------
    X : array-like
        Feature matrix, one row per subject and one column per feature.
    Y : numpy.ndarray
        Integer group labels, one per row of X. The metrics used here
        (2x2 confusion matrix, ROC AUC, Brier score) require exactly two
        classes; label 1 is reported as the seizure case count.
    nr : int, optional
        Number of random label permutations for the permutation test.
        With the default of 0 the permutation test is skipped and its
        columns are printed as nan.

    Returns
    -------
    None
        Results are printed to stdout and a matplotlib figure is shown.
    """
    # Standardise every feature column.
    # NOTE(review): the scaler is fitted on all rows before cross-validation,
    # so held-out folds contribute to the scaling statistics. Wrapping the
    # scaler and each classifier in a Pipeline would avoid this; it is left
    # unchanged here so reported numbers stay comparable with earlier runs.
    scaler = preprocessing.StandardScaler().fit(X)
    X = scaler.transform(X)

    Y_truth = np.array(Y, copy=True)

    # Display name paired with its estimator, so that removing an entry can
    # never leave names and classifiers out of step. More classifiers are
    # possible; see the scikit-learn documentation.
    named_classifiers = [
        ("LogisticRegression", LogisticRegression()),
        ("Nearest Neighbors", KNeighborsClassifier()),
        # Previously tried: max_depth=60, n_estimators=15, max_features=2
        ("Random Forest", RandomForestClassifier()),
        # Previously tried: learning_rate=0.08
        ("Ada Boost", AdaBoostClassifier()),
        # Previously tried: var_smoothing=4
        ("Naive Bayes", GaussianNB()),
        # probability=True is required by the Brier-score scorer.
        ("Linear SVM", SVC(kernel="linear", probability=True)),
        ("RBF SVM", SVC(gamma='auto', kernel='rbf', probability=True)),
    ]

    kfolds = 10

    # Print some headers
    nsubjects = len(Y_truth)
    nseizures = np.count_nonzero(Y_truth == 1)
    print("RESULTS for: %d-fold CV, %d random shuffles." % (kfolds, nr))
    print("N subjects = %d, n seizure cases = %d" % (nsubjects, nseizures))
    print()
    print("%18s   %8s %6s %6s %6s %12s %6s %16s %6s" % (
        "Classifier", "Accuracy", "Sens", "Spec", "Brier (CI-,CI+)", " AU ROC", "(CI-,CI+)", " Rand_acc(sd)", "P-value"))

    for name, clf in named_classifiers:
        # Scoring metrics: area under the ROC curve and Brier score.
        roc_auc = cross_val_score(clf, X, Y_truth, cv=kfolds, scoring='roc_auc')
        brier = cross_val_score(clf, X, Y_truth, cv=kfolds, scoring='neg_brier_score')

        # scikit-learn reports the negated Brier score; flip the sign back.
        b, b_CI1, b_CI2 = _mean_ci95(-brier)
        mean_auc, au_CI1, au_CI2 = _mean_ci95(roc_auc)

        # Accuracy, sensitivity and specificity from out-of-fold predictions.
        y_pred = cross_val_predict(clf, X, Y_truth, cv=kfolds, method='predict')
        tn, fp, fn, tp = confusion_matrix(Y_truth, y_pred).ravel()
        sens = tp / (tp + fn)
        spec = tn / (tn + fp)
        acc = (tp + tn) / (tn + fp + fn + tp)

        # Chance-level accuracy from label permutations. The count follows
        # `nr`, which is also the number announced in the header above.
        if nr > 0:
            _, rand_scores, p = permutation_test_score(
                clf, X, Y_truth, scoring="accuracy", cv=kfolds, n_permutations=nr)
            rand_mean = np.mean(rand_scores)
            rand_sd = np.std(rand_scores)
        else:
            rand_mean = rand_sd = p = float("nan")

        print("%18s   %8.2f %6.2f %6.2f %6.2f (%4.2f,%4.2f) %6.2f (%4.2f,%4.2f); %6.2f (%5.3f)  %6.3e" % (
            name, acc, sens, spec, b, b_CI1, b_CI2, mean_auc, au_CI1, au_CI2, rand_mean, rand_sd, p))

    # Recursive feature elimination
    print("\nRecursive feature elimination")
    svc = SVC(kernel="linear")

    min_features_to_select = 1  # Minimum number of features to consider
    # The "accuracy" scoring is proportional to the number of correct
    # classifications.
    rfecv = RFECV(estimator=svc, step=1, cv=StratifiedKFold(kfolds),
                  scoring='accuracy',
                  min_features_to_select=min_features_to_select)
    rfecv.fit(X, Y_truth)
    rfe_scores = _rfecv_mean_scores(rfecv)
    print("Optimal number of features : %d" % rfecv.n_features_)
    print("grid_scores_: ", rfe_scores)
    print("ranking_: ", rfecv.ranking_)
    print("features: ", all_columns[1:])

    # Plot number of features VS. cross-validation scores
    plt.figure()
    plt.xlabel("Number of features selected")
    plt.ylabel("Cross validation score (nb of correct classifications)")
    plt.plot(range(min_features_to_select,
                   len(rfe_scores) + min_features_to_select),
             rfe_scores)
    plt.show()


if __name__ == "__main__":
    # Exactly one argument is required: the path of the CSV file to analyse.
    argv = sys.argv  # list of strings with all command line inputs
    if len(argv) < 2:
        # Report misuse on stderr and exit with a non-zero status so that
        # calling shells and scripts can detect the failure.
        print("Usage: python %s filename.csv" % argv[0], file=sys.stderr)
        sys.exit(1)
    filename = argv[1]

    # Extra tokens read as missing values: NaN spellings with trailing white
    # space, 'inf', and -1.
    additional_nans = ['NaN ', 'nan ', 'na ', 'inf', 'inf ', -1]

    # Load the data file as a dataframe.
    df_orig = pd.read_csv(filename, skipinitialspace=True, na_values=additional_nans)

    # Keep only the configured columns (label column first), drop every row
    # with a missing value, and take an independent copy so the in-place
    # recoding below never writes into a view of df_orig.
    df = df_orig[all_columns].dropna().copy()

    # Recode string-valued columns as integers (only if they are configured).
    if 'side' in all_columns:
        df.loc[df['side'] == 'right', 'side'] = 1
        df.loc[df['side'] == 'left', 'side'] = 2
    if 'pos' in all_columns:
        df.loc[df['pos'] == 'wrist', 'pos'] = 1
        df.loc[df['pos'] == 'ankle', 'pos'] = 2

    # Y is an array of length n with one entry per subject: the group labels,
    # taken from the first configured column.
    Y = df[all_columns[0]].to_numpy(dtype=int, copy=True)
    # X is an n x m array, where n is the number of rows (subjects) and m is
    # the number of feature columns (every configured column after the first).
    X_columns = all_columns[1:]
    X = df[X_columns].to_numpy(copy=True)

    ################  TEST  #######################
    # Create a dummy feature that's close to the
    # true labels by adding random noise to Y
    if TEST:
        n = len(Y)
        noisy_Y = Y.reshape((n, 1)) + np.random.normal(loc=0.0, scale=noise_sd, size=(n, 1))
        # Register the synthetic column so feature listings include it.
        all_columns.append('test')
        X = np.hstack((X, noisy_Y))
    ################  TEST  #######################

    # Original author's note on the data (not verifiable from this script):
    # data from 15 min analysis ordered by patient ID; sz before y/n,
    # EDA mean, HR mean, RMSSD mean.

    # Try classification using cross-validation
    print()
    print("Labels for classification: ", all_columns[0], ", with values: ", df[all_columns[0]].unique())
    print("Features: ", all_columns[1:])
    classify(X, Y, NSHUFFLES)
